import re
import os
import numpy as np
from sklearn.externals import joblib
from scipy.sparse import csr_matrix, save_npz

# Example of one input row. Fields are separated by "::" and are consumed by
# gen_res in this order: scores, movie_id, movie title (the year in parentheses
# feeds movie_year), movie_type ("|"-separated), user_id, user_gentle,
# user_age, user_occupation.
# 2::3124::Agnes Browne (1999)::Comedy|Drama::5493::M::35::12

# Input files read by gen_res (their first line is skipped as a header).
TRAIN_DATA_PATH = "../../data/ml-1m_20190508/TRAIN_20190506.csv"
TEST_DATA_PATH = "../../data/ml-1m_20190508/TEST_20190506.csv"
# Directory that holds one serialized one-hot encoder per feature.
SOURCE_OH_MODEL_PATH = "../../data/ml-1m_20190508/oh_encoder_20190509.1"
# Output: list of feature names, one "<field>:<category>" entry per line.
GEN_DATA_CONFIG = "../../data/ml-1m_20190508/CONFIG_2019050901_gen.csv"
# Outputs: X is saved as a sparse .npz matrix, Y as comma-separated text.
GEN_DATA_TRAIN_X_PATH = "../../data/ml-1m_20190508/TRAIN_X_gen_2019050901.npz"
GEN_DATA_TRAIN_Y_PATH = "../../data/ml-1m_20190508/TRAIN_Y_gen_2019050901.csv"
GEN_DATA_TEST_X_PATH = "../../data/ml-1m_20190508/TEST_X_gen_2019050901.npz"
GEN_DATA_TEST_Y_PATH = "../../data/ml-1m_20190508/TEST_Y_gen_2019050901.csv"


def get_year(movie):
    """Extract the release year from a movie title like ``"Agnes Browne (1999)"``.

    A title may carry several parenthesised groups (alternate titles, notes,
    ...). The groups are examined from the last one to the first one and the
    first group that parses as an integer is returned, so a trailing
    non-numeric group no longer raises ``ValueError``.

    Args:
        movie: Movie title string.

    Returns:
        The year as an ``int``, or ``0`` when the title has no parenthesised
        group containing an integer.
    """
    groups = re.findall(r'[(](.*?)[)]', movie, re.S)
    for group in reversed(groups):
        try:
            return int(group.strip())
        except ValueError:
            # Not a number (e.g. an alternate title); try the previous group.
            continue
    return 0


def get_movie_type_oh(info, oher):
    """Build a multi-hot vector for a ``|``-separated category string.

    Every category in ``info`` is sent through ``oher.transform`` on its own
    and the resulting rows are added together element by element.

    Args:
        info: Categories joined by ``|``, e.g. ``"Comedy|Drama"``.
        oher: Encoder whose ``transform([[value]])`` returns one row per
            sample; each row must support ``len()`` and ``tolist()``
            (presumably a fitted one-hot encoder with dense output — confirm
            against how the encoders were saved).

    Returns:
        ``numpy.ndarray`` with the element-wise sum of the encoded rows.
    """
    encoded_rows = [oher.transform([[genre]])[0] for genre in info.split("|")]
    # The accumulator is as wide as the first encoded row.
    totals = [0] * len(encoded_rows[0])
    for row in encoded_rows:
        for position, value in enumerate(row.tolist()):
            totals[position] += value
    return np.array(totals)


def sparse_list(array_get):
    """Return the positions and values of the non-zero entries of a sequence.

    Args:
        array_get: Sequence of numbers, e.g. one dense feature row.

    Returns:
        Tuple ``(col, data)`` of two equally long lists: the indices whose
        value differs from zero, and the values found at those indices.
    """
    nonzero = [
        (position, value)
        for position, value in enumerate(array_get)
        if value != 0
    ]
    col = [position for position, _ in nonzero]
    data = [value for _, value in nonzero]
    return col, data


def _encode_fields(fields, oh_encoder):
    """Encode one split input line into (dense feature list, score list)."""
    scores_item = oh_encoder["scores"].transform([[fields[0]]])[0].tolist()
    data_item = []
    data_item.extend(oh_encoder["movie_id"].transform([[fields[1]]])[0].tolist())
    data_item.extend(oh_encoder["movie_year"].transform([[get_year(fields[2])]])[0].tolist())
    data_item.extend(get_movie_type_oh(fields[3], oh_encoder["movie_type"]).tolist())
    data_item.extend(oh_encoder["user_id"].transform([[fields[4]]])[0].tolist())
    data_item.extend(oh_encoder["user_gentle"].transform([[fields[5]]])[0].tolist())
    data_item.extend(oh_encoder["user_age"].transform([[fields[6]]])[0].tolist())
    data_item.extend(oh_encoder["user_occupation"].transform([[fields[7]]])[0].tolist())
    return data_item, scores_item


def gen_res(source_data, oh_encoder, n_features=9831):
    """Encode a ``::``-separated ratings file into a sparse X matrix and Y rows.

    The first line of the file is treated as a header and skipped; blank lines
    are ignored. Every other line is split on ``::`` into eight fields (score,
    movie id, movie title, ``|``-separated movie types, user id, gender, age,
    occupation). The score is encoded into Y; the remaining fields are encoded
    and concatenated, in that order, into one row of X.

    Args:
        source_data: Path of the UTF-8 encoded input file.
        oh_encoder: Mapping with the keys ``"scores"``, ``"movie_id"``,
            ``"movie_year"``, ``"movie_type"``, ``"user_id"``,
            ``"user_gentle"``, ``"user_age"`` and ``"user_occupation"``. Each
            value must offer ``transform([[value]])`` returning rows that
            support ``tolist()`` (presumably fitted dense one-hot encoders —
            confirm against the saved models).
        n_features: Number of columns of the returned matrix. Defaults to
            9831, the width that used to be hard-coded here.

    Returns:
        Tuple ``(x_res, y_res)``: ``x_res`` is a ``csr_matrix`` with one row
        per data line and ``n_features`` columns, ``y_res`` is a list holding
        the encoded score list of every data line, in file order.
    """
    row_all = []
    col_all = []
    data_all = []
    y_res = []
    with open(source_data, encoding="utf8") as f:
        next(f, None)  # skip the header line (no-op for an empty file)
        for line in f:
            line = line.strip()
            if not line:
                continue
            data_item, scores_item = _encode_fields(line.split("::"), oh_encoder)
            # The row index equals the number of samples stored so far, which
            # keeps X and Y aligned even when blank lines were skipped.
            row_idx = len(y_res)
            # Y handling
            y_res.append(scores_item)
            # X handling: keep only the non-zero entries of the dense row.
            col, data = sparse_list(data_item)
            # extend() appends in place; rebuilding the lists with "+" on
            # every line would copy everything accumulated so far.
            col_all.extend(col)
            row_all.extend([row_idx] * len(col))
            data_all.extend(data)
            if len(y_res) % 10000 == 0:
                print("generating %s data items" % len(y_res))
    # Use len(y_res) for the row count so that X always has exactly one row
    # per Y entry, including for a header-only file or all-zero trailing rows.
    x_res = csr_matrix(
        (data_all, (row_all, col_all)), shape=(len(y_res), n_features)
    )
    return x_res, y_res


# Load every one-hot encoder stored in the model directory.
list_oh_path = os.listdir(SOURCE_OH_MODEL_PATH)
# Each encoder is keyed by its file name without the last six characters
# (presumably a fixed-length suffix — confirm against the saved files).
oh_encoder = {
    model_file[:-6]: joblib.load("%s/%s" % (SOURCE_OH_MODEL_PATH, model_file))
    for model_file in list_oh_path
}

# Feature names, one "<field>:<category>" entry per one-hot column. "scores"
# (the Y encoding) comes first; the remaining fields follow the order in which
# gen_res concatenates them into X.
features = []
for field in (
    "scores",
    "movie_id",
    "movie_year",
    "movie_type",
    "user_id",
    "user_gentle",
    "user_age",
    "user_occupation",
):
    # "%s" formatting stringifies every category the same way, so a field with
    # non-string categories cannot raise TypeError during concatenation.
    features.extend(
        ["%s:%s" % (field, category)
         for category in oh_encoder[field].categories_[0].tolist()]
    )

# Feature-engineering config file: one feature name per line.
with open(GEN_DATA_CONFIG, "w", encoding="utf8") as f:
    f.writelines("%s\n" % line for line in features)

# Generate the training data.
print("generating training data")
x_train, y_train = gen_res(TRAIN_DATA_PATH, oh_encoder)
# Y: one line per sample, values separated by commas.
with open(GEN_DATA_TRAIN_Y_PATH, "w", encoding='utf8') as y_file:
    y_file.writelines(
        "%s\n" % ",".join(map(str, score_row)) for score_row in y_train
    )
# X: sparse matrix stored in SciPy's .npz format.
save_npz(GEN_DATA_TRAIN_X_PATH, x_train)
print("training data generation done")

# Generate the test data. The results get their own names (x_test / y_test)
# so the test split is not mislabelled as, and does not overwrite, the
# training split.
print("generating testing data")
x_test, y_test = gen_res(TEST_DATA_PATH, oh_encoder)
# Y: one line per sample, values separated by commas.
with open(GEN_DATA_TEST_Y_PATH, "w", encoding='utf8') as f:
    for item in y_test:
        f.write("%s\n" % (",".join([str(i) for i in item])))
# X: sparse matrix stored in SciPy's .npz format.
save_npz(GEN_DATA_TEST_X_PATH, x_test)
print("testing data generation done")